{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 2.8 推論の実施の付録。学習と検証のDataLoaderに実施する\n",
    "\n",
    "本ファイルでは、学習させたSSDで物体検出を行います。\n",
    "\n",
    "\n",
    "VOC2012の訓練データセットと検証データセットに対して、学習済みSSDの推論を実施し、推論結果と正しい答えであるアノテーションデータの両方を表示させるファイルです。\n",
    "\n",
    "学習させたSSDモデルが正しいアノテーションデータとどれくらい近いのかなどを確認したいケースでは、こちらもご使用ください。\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 事前準備\n",
    "\n",
    "- フォルダ「utils」に2.3～2.7までで実装した内容をまとめたssd_model.pyがあることを確認してください\n",
    "- 学習させた重みパラメータを用意"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import cv2  # OpenCVライブラリ\n",
    "import matplotlib.pyplot as plt \n",
    "import numpy as np\n",
    "import torch\n",
    "\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 推論用の関数とクラスを作成する"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def ssd_predict(img_index, img_list, dataset, net=None, dataconfidence_level=0.5):\n",
    "    \"\"\"\n",
    "    SSDで予測させる関数。\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    img_index:  int\n",
    "        データセット内の予測対象画像のインデックス。\n",
    "    img_list: list\n",
    "        画像のファイルパスのリスト\n",
    "    dataset: PyTorchのDataset\n",
    "        画像のDataset\n",
    "    net: PyTorchのNetwork\n",
    "        学習させたSSDネットワーク\n",
    "    dataconfidence_level: float\n",
    "        予測で発見とする確信度の閾値\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    rgb_img, true_bbox, true_label_index, predict_bbox, pre_dict_label_index, scores\n",
    "    \"\"\"\n",
    "\n",
    "    # rgbの画像データを取得\n",
    "    image_file_path = img_list[img_index]\n",
    "    img = cv2.imread(image_file_path)  # [高さ][幅][色BGR]\n",
    "    height, width, channels = img.shape  # 画像のサイズを取得\n",
    "    rgb_img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)\n",
    "\n",
    "    # 正解のBBoxを取得\n",
    "    im, gt = dataset.__getitem__(img_index)\n",
    "    true_bbox = gt[:, 0:4] * [width, height, width, height]\n",
    "    true_label_index = gt[:, 4].astype(int)\n",
    "\n",
    "    # SSDで予測\n",
    "    net.eval()  # ネットワークを推論モードへ\n",
    "    x = im.unsqueeze(0)  # ミニバッチ化：torch.Size([1, 3, 300, 300])\n",
    "    detections = net(x)\n",
    "    # detectionsの形は、torch.Size([1, 21, 200, 5])  ※200はtop_kの値\n",
    "\n",
    "    # confidence_levelが基準以上を取り出す\n",
    "    predict_bbox = []\n",
    "    pre_dict_label_index = []\n",
    "    scores = []\n",
    "    detections = detections.cpu().detach().numpy()\n",
    "\n",
    "    # 条件以上の値を抽出\n",
    "    find_index = np.where(detections[:, 0:, :, 0] >= dataconfidence_level)\n",
    "    detections = detections[find_index]\n",
    "    for i in range(len(find_index[1])):  # 抽出した物体数分ループを回す\n",
    "        if (find_index[1][i]) > 0:  # 背景クラスでないもの\n",
    "            sc = detections[i][0]  # 確信度\n",
    "            bbox = detections[i][1:] * [width, height, width, height]\n",
    "            lable_ind = find_index[1][i]-1  # find_indexはミニバッチ数、クラス、topのtuple\n",
    "            # （注釈）\n",
    "            # 背景クラスが0なので1を引く\n",
    "\n",
    "            # 返り値のリストに追加\n",
    "            predict_bbox.append(bbox)\n",
    "            pre_dict_label_index.append(lable_ind)\n",
    "            scores.append(sc)\n",
    "\n",
    "    return rgb_img, true_bbox, true_label_index, predict_bbox, pre_dict_label_index, scores\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def vis_bbox(rgb_img, bbox, label_index, scores, label_names):\n",
    "    \"\"\"\n",
    "    物体検出の予測結果を画像で表示させる関数。\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    rgb_img:rgbの画像\n",
    "        対象の画像データ\n",
    "    bbox: list\n",
    "        物体のBBoxのリスト\n",
    "    label_index: list\n",
    "        物体のラベルへのインデックス\n",
    "    scores: list\n",
    "        物体の確信度。\n",
    "    label_names: list\n",
    "        ラベル名の配列\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    なし。rgb_imgに物体検出結果が加わった画像が表示される。\n",
    "    \"\"\"\n",
    "\n",
    "    # 枠の色の設定\n",
    "    num_classes = len(label_names)  # クラス数（背景のぞく）\n",
    "    colors = plt.cm.hsv(np.linspace(0, 1, num_classes)).tolist()\n",
    "\n",
    "    # 画像の表示\n",
    "    plt.figure(figsize=(10, 10))\n",
    "    plt.imshow(rgb_img)\n",
    "    currentAxis = plt.gca()\n",
    "\n",
    "    # BBox分のループ\n",
    "    for i, bb in enumerate(bbox):\n",
    "\n",
    "        # ラベル名\n",
    "        label_name = label_names[label_index[i]]\n",
    "        color = colors[label_index[i]]  # クラスごとに別の色の枠を与える\n",
    "\n",
    "        # 枠につけるラベル　例：person;0.72　\n",
    "        if scores is not None:\n",
    "            sc = scores[i]\n",
    "            display_txt = '%s: %.2f' % (label_name, sc)\n",
    "        else:\n",
    "            display_txt = '%s: ans' % (label_name)\n",
    "\n",
    "        # 枠の座標\n",
    "        xy = (bb[0], bb[1])\n",
    "        width = bb[2] - bb[0]\n",
    "        height = bb[3] - bb[1]\n",
    "\n",
    "        # 長方形を描画する\n",
    "        currentAxis.add_patch(plt.Rectangle(\n",
    "            xy, width, height, fill=False, edgecolor=color, linewidth=2))\n",
    "\n",
    "        # 長方形の枠の左上にラベルを描画する\n",
    "        currentAxis.text(xy[0], xy[1], display_txt, bbox={\n",
    "                         'facecolor': color, 'alpha': 0.5})\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SSDPredictShow():\n",
    "    \"\"\"SSDでの予測と画像の表示をまとめて行うクラス\"\"\"\n",
    "\n",
    "    def __init__(self, img_list, dataset,  eval_categories, net=None, dataconfidence_level=0.6):\n",
    "        self.img_list = img_list\n",
    "        self.dataset = dataset\n",
    "        self.net = net\n",
    "        self.dataconfidence_level = dataconfidence_level\n",
    "        self.eval_categories = eval_categories\n",
    "\n",
    "    def show(self, img_index, predict_or_ans):\n",
    "        \"\"\"\n",
    "        物体検出の予測と表示をする関数。\n",
    "\n",
    "        Parameters\n",
    "        ----------\n",
    "        img_index:  int\n",
    "            データセット内の予測対象画像のインデックス。\n",
    "        predict_or_ans: text\n",
    "            'precit'もしくは'ans'でBBoxの予測と正解のどちらを表示させるか指定する\n",
    "\n",
    "        Returns\n",
    "        -------\n",
    "        なし。rgb_imgに物体検出結果が加わった画像が表示される。\n",
    "        \"\"\"\n",
    "        rgb_img, true_bbox, true_label_index, predict_bbox, pre_dict_label_index, scores = ssd_predict(img_index, self.img_list,\n",
    "                                                                 self.dataset,\n",
    "                                                                 self.net,\n",
    "                                                                 self.dataconfidence_level)\n",
    "\n",
    "        if predict_or_ans == \"predict\":\n",
    "            vis_bbox(rgb_img, bbox=predict_bbox, label_index=pre_dict_label_index,\n",
    "                     scores=scores, label_names=self.eval_categories)\n",
    "\n",
    "        elif predict_or_ans == \"ans\":\n",
    "            vis_bbox(rgb_img, bbox=true_bbox, label_index=true_label_index,\n",
    "                     scores=None, label_names=self.eval_categories)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 推論を実行する"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils.ssd_model import make_datapath_list, VOCDataset, DataTransform, Anno_xml2list, od_collate_fn\n",
    "\n",
    "\n",
    "# ファイルパスのリストを取得\n",
    "rootpath = \"./data/VOCdevkit/VOC2012/\"\n",
    "train_img_list, train_anno_list, val_img_list, val_anno_list = make_datapath_list(\n",
    "    rootpath)\n",
    "\n",
    "# Datasetを作成\n",
    "voc_classes = ['aeroplane', 'bicycle', 'bird', 'boat',\n",
    "               'bottle', 'bus', 'car', 'cat', 'chair',\n",
    "               'cow', 'diningtable', 'dog', 'horse',\n",
    "               'motorbike', 'person', 'pottedplant',\n",
    "               'sheep', 'sofa', 'train', 'tvmonitor']\n",
    "color_mean = (104, 117, 123)  # (BGR)の色の平均値\n",
    "input_size = 300  # 画像のinputサイズを300×300にする\n",
    "\n",
    "train_dataset = VOCDataset(train_img_list, train_anno_list, phase=\"val\", transform=DataTransform(\n",
    "    input_size, color_mean), transform_anno=Anno_xml2list(voc_classes))\n",
    "\n",
    "val_dataset = VOCDataset(val_img_list, val_anno_list, phase=\"val\", transform=DataTransform(\n",
    "    input_size, color_mean), transform_anno=Anno_xml2list(voc_classes))\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils.ssd_model import SSD\n",
    "\n",
    "# SSD300の設定\n",
    "ssd_cfg = {\n",
    "    'num_classes': 21,  # 背景クラスを含めた合計クラス数\n",
    "    'input_size': 300,  # 画像の入力サイズ\n",
    "    'bbox_aspect_num': [4, 6, 6, 6, 4, 4],  # 出力するDBoxのアスペクト比の種類\n",
    "    'feature_maps': [38, 19, 10, 5, 3, 1],  # 各sourceの画像サイズ\n",
    "    'steps': [8, 16, 32, 64, 100, 300],  # DBOXの大きさを決める\n",
    "    'min_sizes': [30, 60, 111, 162, 213, 264],  # DBOXの大きさを決める\n",
    "    'max_sizes': [60, 111, 162, 213, 264, 315],  # DBOXの大きさを決める\n",
    "    'aspect_ratios': [[2], [2, 3], [2, 3], [2, 3], [2], [2]],\n",
    "}\n",
    "\n",
    "# SSDネットワークモデル\n",
    "net = SSD(phase=\"inference\", cfg=ssd_cfg)\n",
    "net.eval()\n",
    "\n",
    "# SSDの学習済みの重みを設定\n",
    "net_weights = torch.load('./weights/ssd300_50.pth',\n",
    "                         map_location={'cuda:0': 'cpu'})\n",
    "\n",
    "#net_weights = torch.load('./weights/ssd300_mAP_77.43_v2.pth',\n",
    "#                         map_location={'cuda:0': 'cpu'})\n",
    "\n",
    "net.load_state_dict(net_weights)\n",
    "\n",
    "# GPUが使えるかを確認\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "print(\"使用デバイス：\", device)\n",
    "\n",
    "print('ネットワーク設定完了：学習済みの重みをロードしました')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 結果の描画\n",
    "ssd = SSDPredictShow(img_list=train_img_list, dataset=train_dataset, eval_categories=voc_classes,\n",
    "                     net=net, dataconfidence_level=0.6)\n",
    "img_index = 0\n",
    "ssd.show(img_index, \"predict\")\n",
    "ssd.show(img_index, \"ans\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 結果の描画\n",
    "ssd = SSDPredictShow(img_list=val_img_list, dataset=val_dataset, eval_categories=voc_classes,\n",
    "                     net=net, dataconfidence_level=0.6)\n",
    "img_index = 0\n",
    "ssd.show(img_index, \"predict\")\n",
    "ssd.show(img_index, \"ans\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "以上"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
